#!/usr/bin/env python3
import os
import sys
import argparse
import numpy as np
import re
from glob import glob
import time
import importlib
import tensorflow as tf
import pickle

sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
from datasets.dataset_reader_argoverse import read_data_val
from datasets.helper import get_lane_direction
from traffic_evaluation_helper import TrafficErrors


def match_pr_gt(pr, gt, pr_id, gt_id, tensor=True):
    """Build a pseudo ground truth that has one row per predicted agent.

    Agents whose id occurs in both ``pr_id`` and ``gt_id`` receive their
    ground-truth row. Every other row keeps the prediction shifted by a
    scalar offset derived from the prediction/ground-truth differences of
    the shared agents.

    Args:
        pr: Predicted values, one row per entry of ``pr_id``. May be a NumPy
            array or any object exposing ``.numpy()`` (e.g. an eager
            ``tf.Tensor``). It is not modified.
        gt: Ground-truth values, one row per entry of ``gt_id``.
        pr_id: Ids of the predicted rows.
        gt_id: Ids of the ground-truth rows.
        tensor: When true, the result is converted to a float32 tensor with
            ``tf.convert_to_tensor``; otherwise a NumPy array is returned.

    Returns:
        An array (or tensor) with the same shape as ``pr``.
    """
    def as_array(value):
        # Unwrap anything that exposes .numpy() (eager tensors); the original
        # code referenced the method without calling it for the id arguments.
        to_numpy = getattr(value, 'numpy', None)
        return np.asarray(to_numpy() if callable(to_numpy) else value)

    pr_id = as_array(pr_id).ravel()
    gt_id = as_array(gt_id).ravel()
    gt = as_array(gt)

    # Work on a copy so the caller's prediction (or the tensor's buffer, which
    # .numpy() may share) is not altered by the in-place updates below.
    gt_match = np.array(as_array(pr), copy=True)

    # Index pairs of the shared ids; pr_idx[k] and gt_idx[k] refer to the same
    # id, so rows are paired by id even if the two id lists are ordered
    # differently.
    _, pr_idx, gt_idx = np.intersect1d(pr_id, gt_id, return_indices=True)

    # NOTE(review): no axis is passed, so this is a single norm over the whole
    # difference array and the mean is taken of one scalar. A per-agent mean
    # distance would need axis=-1 -- confirm the intended definition.
    adj_term = np.mean(np.linalg.norm(gt_match[pr_idx] - gt[gt_idx]))
    gt_match += adj_term
    gt_match[pr_idx] = gt[gt_idx]

    if tensor:
        gt_match = tf.convert_to_tensor(gt_match, dtype=np.float32)

    return gt_match

def get_agent(pr, gt, pr_id, gt_id, agent_id, tensor=True):
    """Select the rows belonging to one agent from prediction and ground truth.

    Args:
        pr: Predicted values, one row per entry of ``pr_id``. May be a NumPy
            array or any object exposing ``.numpy()`` (e.g. an eager
            ``tf.Tensor``).
        gt: Ground-truth values, one row per entry of ``gt_id``.
        pr_id: Ids of the predicted rows.
        gt_id: Ids of the ground-truth rows.
        agent_id: Id of the agent to extract.
        tensor: Unused; kept so existing callers keep working.

    Returns:
        A tuple ``(pr_agent, gt_agent)`` of NumPy arrays holding the rows of
        ``pr`` and ``gt`` whose id equals ``agent_id``.
    """
    def unwrap(value):
        # Eager tensors expose .numpy(); everything else is passed through.
        to_numpy = getattr(value, 'numpy', None)
        return to_numpy() if callable(to_numpy) else value

    # Previously pr_array was left unbound whenever pr was not a tf.Tensor.
    pr_array = np.asarray(unwrap(pr))
    gt_array = np.asarray(unwrap(gt))
    agent_id = unwrap(agent_id)

    # Compare as arrays so the result is an element-wise mask even when the
    # ids arrive as plain sequences.
    pr_agent = pr_array[np.asarray(unwrap(pr_id)) == agent_id]
    gt_agent = gt_array[np.asarray(unwrap(gt_id)) == agent_id]

    return pr_agent, gt_agent


def _rollout_window(model, window, fluid_errors, scene_id, steps):
    """Roll the model forward over one window and record per-step errors.

    The model is seeded with the first frame's ``pos0``/``vel0`` and then fed
    its own output for ``steps`` iterations. After each step the tracked
    agent's prediction is compared against the following frame and passed to
    ``fluid_errors.add_errors``.

    Returns:
        List with the agent's predicted rows, one entry per step.
    """
    first = window[0]
    lane = first['lane'][0]
    lane_normals = first['lane_norm'][0]
    agent_id = first['agent_id'][0]
    # Predictions keep the row layout of the first frame, so its track ids
    # identify the predicted rows for the whole rollout.
    track_ids = first['track_id0'][0]

    pos, vel = first['pos0'][0], first['vel0'][0]
    pred = []
    for i in range(steps):
        pos, vel = model((pos, vel, None, lane, lane_normals))
        target = window[i + 1]
        pr_agent, gt_agent = get_agent(pos, target['pos0'][0],
                                       track_ids,
                                       target['track_id0'][0],
                                       agent_id)
        fluid_errors.add_errors(scene_id, window[i]['frame_id'][0],
                                target['frame_id'][0], pr_agent,
                                gt_agent)
        pred.append(pr_agent)
    return pred


def _summarize_errors(fluid_errors):
    """Aggregate the recorded per-step mean errors into summary statistics."""
    # Group the 'mean' entries by the first element of each error key.
    de = {}
    for k, v in fluid_errors.errors.items():
        de.setdefault(k[0], []).append(v['mean'])

    ade = []
    de1s = []
    de2s = []
    de3s = []
    for v in de.values():
        ade.append(np.mean(v))
        # 10th, 20th and last recorded step of each group.
        # presumably 1 s / 2 s / 3 s at 10 frames per second -- TODO confirm
        de1s.append(v[9])
        de2s.append(v[19])
        de3s.append(v[-1])

    result = {}
    result['ADE'] = np.mean(ade)
    result['ADE_std'] = np.std(ade)
    result['DE@1s'] = np.mean(de1s)
    result['DE@1s_std'] = np.std(de1s)
    result['DE@2s'] = np.mean(de2s)
    result['DE@2s_std'] = np.std(de2s)
    result['DE@3s'] = np.mean(de3s)
    result['DE@3s_std'] = np.std(de3s)
    return result


def evaluate(model, val_dataset, frame_skip, fluid_errors=None, am=None):
    """Run autoregressive rollouts over the validation data and report errors.

    Frames are collected into windows of 30 consecutive entries (those whose
    ``frame_id % frame_skip`` lies in 0..29). For each complete window that
    belongs to a single scene, the model is rolled out for 29 steps and the
    tracked agent's errors are recorded. Evaluation stops after 5000 windows.
    The per-scene predictions are pickled to ``prediction_1t_map_5k.pickle``
    in the current working directory.

    Args:
        model: Callable taking ``(pos, vel, None, lane, lane_normals)`` and
            returning ``(next_pos, next_vel)``.
        val_dataset: Iterable of frame dicts providing the keys ``frame_id``,
            ``scene_idx``, ``lane``, ``lane_norm``, ``agent_id``, ``pos0``,
            ``vel0`` and ``track_id0``.
        frame_skip: Modulus applied to ``frame_id`` to locate window
            boundaries.
        fluid_errors: Error accumulator; a new ``TrafficErrors`` is created
            when omitted.
        am: Map object; an ``ArgoverseMap`` is created when omitted. It is not
            otherwise used by this function.

    Returns:
        Dict with the keys ``ADE``, ``DE@1s``, ``DE@2s``, ``DE@3s`` and their
        ``*_std`` counterparts.
    """
    print('evaluating.. ', end='', flush=True)

    if fluid_errors is None:
        fluid_errors = TrafficErrors()

    if am is None:
        from argoverse.map_representation.map_api import ArgoverseMap
        am = ArgoverseMap()

    seq_len = 30                 # frames per window
    last_phase = seq_len - 1     # frame_id % frame_skip that closes a window

    count = 0
    frames = []
    predictions = {}
    for data in val_dataset:
        frame_id = data['frame_id'][0]
        if frame_id == 0:
            frames = []
        phase = frame_id % frame_skip
        if phase <= last_phase:
            frames.append(data)
        if phase != last_phase:
            continue

        # The window is closed: take it and start collecting afresh.
        window, frames = frames, []

        # The rollout reads frames 0..29, so an incomplete window would
        # raise IndexError; skip it instead.
        if len(window) < seq_len:
            continue
        window = window[:seq_len]

        # Every frame that is used, including the last ground-truth frame,
        # must come from the same scene.
        if len({frame['scene_idx'][0] for frame in window}) != 1:
            continue

        scene_id = window[0]['scene_idx'][0]

        if count % 250 == 0:
            print('{}({})'.format(count + 1, scene_id), end=' ', flush=True)
        count += 1

        # Store the result before the early exit below so the final window's
        # predictions are not lost, and only for windows actually evaluated.
        predictions[scene_id] = _rollout_window(model, window, fluid_errors,
                                                scene_id, seq_len - 1)

        if count % 5000 == 0:
            break

    with open('prediction_1t_map_5k.pickle', 'wb') as f:
        pickle.dump(predictions, f)

    result = _summarize_errors(fluid_errors)

    print(result)
    print('done')

    return result


